{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy\n",
    "\n",
    "# fids = numpy.load('/data/vision/torralba/scratch2/junyanz/gandissect/fid_results/bedroom_10000_s0_e512.npy')\n",
    "fids = numpy.load('/data/vision/torralba/scratch2/junyanz/gandissect/fid_results/bedroom_s0_e512.npy')\n",
    "\n",
    "manual_units = [457, 231, 63, 147, 112, 438, 304, 306, 271, 293, 511, 435, 333, 289, 163, 25, 23, 30, 9, 151]\n",
    "fid_units = [231, 112, 457, 151, 63, 511, 9, 224, 230, 43, 195, 288, 342, 350, 123, 147, 25, 438, 56, 100]\n",
    "merged_units = manual_units + [u for u in fid_units if u not in manual_units]\n",
    "merged_units"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(4,3))\n",
    "plt.hist(fids, bins=numpy.linspace(10,48,39), label='all layer4 units')\n",
    "plt.hist(fids[manual_units], bins=numpy.linspace(10,48,39), label='manually flagged')\n",
    "plt.legend()\n",
    "plt.ylabel('Number of units')\n",
    "plt.xlabel('FID score (lower is better)')\n",
    "plt.title('Per-Unit FID scores')\n",
    "plt.tight_layout()\n",
    "plt.savefig('artifact-histogram.svg')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(4,3))\n",
    "plt.hist(fids, bins=numpy.linspace(10,48,39), label='all layer4 units')\n",
    "plt.hist(fids[manual_units], bins=numpy.linspace(10,48,39), label='manually flagged')\n",
    "plt.legend()\n",
    "plt.ylabel('Number of units')\n",
    "plt.xlabel('FID score (lower is better)')\n",
    "plt.title('Per-Unit FID scores')\n",
    "plt.tight_layout()\n",
    "plt.savefig('artifact-histogram.svg')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Interesting cases: the worst-FID manual units\n",
    "manual_units[numpy.argmin(fids[manual_units])], fids[435]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Interesting cases: the worst-FID unit overall\n",
    "numpy.argmax(fids), fids[231]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the worst-FID unit missed by manual inspection\n",
    "mask = numpy.ones_like(fids)\n",
    "mask[manual_units] = 0\n",
    "numpy.argmax(fids * mask), fids[230]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}